{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BGE Explanation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we walk through the structure of BGE and BGE-v1.5 and how they generate embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Installation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the required packages in your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "%pip install -U transformers FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Encode sentences"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see exactly how a sentence is encoded, let's first load the tokenizer and model with Hugging Face Transformers instead of FlagEmbedding:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
    "model = AutoModel.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n",
    "\n",
    "sentences = [\"embedding\", \"I love machine learning and nlp\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the following cell to inspect the architecture of bge-base-en-v1.5. It uses BERT-base as its backbone, with 12 encoder layers and a hidden dimension of 768.\n",
    "\n",
    "Note that corresponding BGE and BGE-v1.5 models share the same architecture. For example, bge-base-en and bge-base-en-v1.5 have the same structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertModel(\n",
       "  (embeddings): BertEmbeddings(\n",
       "    (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "    (position_embeddings): Embedding(512, 768)\n",
       "    (token_type_embeddings): Embedding(2, 768)\n",
       "    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (encoder): BertEncoder(\n",
       "    (layer): ModuleList(\n",
       "      (0-11): 12 x BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pooler): BertPooler(\n",
       "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (activation): Tanh()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
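  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the BERT-base description, the cell below recomputes the parameter count by hand from the layer shapes printed above. Each `Linear` contributes a weight matrix plus a bias, and each `LayerNorm` contributes a scale vector plus a shift vector. It is plain arithmetic based only on the printed module list. For the loaded model, `sum(p.numel() for p in model.parameters())` should give the same number."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter count of the BertModel printed above, derived from its layer shapes\n",
    "vocab, max_pos, type_vocab, hidden, ffn, n_layers = 30522, 512, 2, 768, 3072, 12\n",
    "\n",
    "def linear(n_in, n_out):\n",
    "    # weight matrix plus bias vector\n",
    "    return n_in * n_out + n_out\n",
    "\n",
    "layer_norm = 2 * hidden  # scale (gamma) and shift (beta) vectors\n",
    "\n",
    "emb_params = (vocab + max_pos + type_vocab) * hidden + layer_norm\n",
    "layer_params = (\n",
    "    4 * linear(hidden, hidden)  # query, key, value and attention output projections\n",
    "    + layer_norm                # LayerNorm after attention\n",
    "    + linear(hidden, ffn)       # intermediate (feed-forward up-projection)\n",
    "    + linear(ffn, hidden)       # output (feed-forward down-projection)\n",
    "    + layer_norm                # LayerNorm after the feed-forward block\n",
    ")\n",
    "pooler_params = linear(hidden, hidden)\n",
    "\n",
    "total_params = emb_params + n_layers * layer_params + pooler_params\n",
    "print(f'{total_params:,} parameters')  # 109,482,240 parameters"
   ]
  },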
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's tokenize the sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[  101,  7861,  8270,  4667,   102,     0,     0,     0,     0],\n",
       "        [  101,  1045,  2293,  3698,  4083,  1998, 17953,  2361,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = tokenizer(\n",
    "    sentences, \n",
    "    padding=True, \n",
    "    truncation=True, \n",
    "    return_tensors='pt', \n",
    "    max_length=512\n",
    ")\n",
    "inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
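    "From the results, we can see that each sentence begins with token 101 (`[CLS]`) and that its last real token is 102 (`[SEP]`). These are the special tokens used by BERT. The first sentence is shorter, so it is padded with token 0 (`[PAD]`) up to the length of the longest sentence, and its `attention_mask` is 0 at those padded positions."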
   ]
  },
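  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The attention mask follows directly from the padding. It is 1 for real tokens and 0 for padding, where padding uses token id 0 (matching `padding_idx=0` in the word embeddings above). The sketch below rebuilds the mask from the `input_ids` printed above. It assumes id 0 only ever appears as padding, which holds here. In general, prefer `tokenizer.pad_token_id` and the mask returned by the tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# input_ids copied from the tokenizer output above\n",
    "input_ids = torch.tensor([\n",
    "    [101, 7861, 8270, 4667, 102, 0, 0, 0, 0],\n",
    "    [101, 1045, 2293, 3698, 4083, 1998, 17953, 2361, 102],\n",
    "])\n",
    "pad_id = 0  # id of [PAD] for this tokenizer (assumed to appear only as padding)\n",
    "\n",
    "# 1 where there is a real token, 0 where there is padding\n",
    "rebuilt_mask = (input_ids != pad_id).long()\n",
    "print(rebuilt_mask)\n",
    "\n",
    "# every sequence starts with [CLS] (101), and its last real token is [SEP] (102)\n",
    "lengths = rebuilt_mask.sum(dim=1)\n",
    "print(input_ids[:, 0], input_ids[torch.arange(2), lengths - 1])"
   ]
  },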
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 9, 768])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
    "last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we implement the pooling function with two options: using the last hidden state of `[CLS]`, or mean pooling over the last hidden states of all non-padding tokens (as indicated by the attention mask)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
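    "def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):\n",
    "    if pooling_method == 'cls':\n",
    "        # use the last hidden state of the first token, [CLS]\n",
    "        return last_hidden_state[:, 0]\n",
    "    elif pooling_method == 'mean':\n",
    "        # average over real tokens only: zero out padding, then divide by the token count\n",
    "        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)\n",
    "        d = attention_mask.sum(dim=1, keepdim=True).float()\n",
    "        return s / d\n",
    "    else:\n",
    "        raise ValueError(f'Unknown pooling method: {pooling_method}')"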
   ]
  },
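  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Masking matters for mean pooling because padded positions still receive hidden states. The toy example below uses made-up numbers: a single sequence of 3 tokens with hidden size 2, where the last token is padding. It compares a naive `.mean(dim=1)` with the masked mean computed the same way as in `pooling`, and only the masked version ignores the padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# toy last hidden state: batch of 1, sequence of 3 tokens, hidden size 2\n",
    "# the third position is padding\n",
    "toy_hidden = torch.tensor([[[1.0, 2.0],\n",
    "                            [3.0, 4.0],\n",
    "                            [100.0, 100.0]]])\n",
    "toy_mask = torch.tensor([[1, 1, 0]])\n",
    "\n",
    "# naive mean includes the padding position\n",
    "naive = toy_hidden.mean(dim=1)\n",
    "\n",
    "# masked mean, as in the pooling function: zero out padding, divide by the real-token count\n",
    "masked = (toy_hidden * toy_mask.unsqueeze(-1).float()).sum(dim=1) / toy_mask.sum(dim=1, keepdim=True).float()\n",
    "\n",
    "print(naive)   # tensor([[34.6667, 35.3333]])\n",
    "print(masked)  # tensor([[2., 3.]])"
   ]
  },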
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike the more commonly used mean pooling, BGE is trained to use the last hidden state of `[CLS]` as the sentence embedding:\n",
    "\n",
    "`sentence_embeddings = model_output[0][:, 0]`\n",
    "\n",
    "Here `model_output[0]` is the last hidden state, the same tensor as `last_hidden_state` above. Using mean pooling instead significantly degrades performance, so make sure to use the correct method to obtain sentence vectors."
   ]
  },
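  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tokenizer pads on the right (note the trailing zeros in `input_ids` above), so `[CLS]` always sits at index 0. As a result, CLS pooling never reads a padded position and does not need the attention mask. The toy sketch below illustrates this with random values standing in for real hidden states: overwriting every padded position leaves the CLS vectors unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# toy last hidden state: 2 sequences, 4 positions, hidden size 3 (random stand-in values)\n",
    "last_hidden = torch.randn(2, 4, 3)\n",
    "toy_mask = torch.tensor([[1, 1, 0, 0],\n",
    "                         [1, 1, 1, 1]])\n",
    "\n",
    "cls_before = last_hidden[:, 0]  # shape (2, 3): one vector per sequence\n",
    "\n",
    "# out-of-place: overwrite every padded position with a large value\n",
    "perturbed = last_hidden.masked_fill(toy_mask.unsqueeze(-1) == 0, 1e4)\n",
    "cls_after = perturbed[:, 0]\n",
    "\n",
    "print(cls_before.shape)                    # torch.Size([2, 3])\n",
    "print(torch.equal(cls_before, cls_after))  # True"
   ]
  },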
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 768])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings = pooling(\n",
    "    last_hidden_state, \n",
    "    pooling_method='cls', \n",
    "    attention_mask=inputs['attention_mask']\n",
    ")\n",
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting these pieces together, we get the complete encoding function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode(sentences, max_length=512, convert_to_numpy=True):\n",
    "\n",
    "    # handle the case of single sentence and a list of sentences\n",
    "    input_was_string = False\n",
    "    if isinstance(sentences, str):\n",
    "        sentences = [sentences]\n",
    "        input_was_string = True\n",
    "\n",
    "    inputs = tokenizer(\n",
    "        sentences, \n",
    "        padding=True, \n",
    "        truncation=True, \n",
    "        return_tensors='pt', \n",
    "        max_length=max_length\n",
    "    )\n",
    "\n",
    "    # inference only, so skip gradient tracking\n",
    "    with torch.no_grad():\n",
    "        last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n",
    "    \n",
    "    embeddings = pooling(\n",
    "        last_hidden_state, \n",
    "        pooling_method='cls', \n",
    "        attention_mask=inputs['attention_mask']\n",
    "    )\n",
    "\n",
    "    # L2-normalize the embeddings so that dot products equal cosine similarities\n",
    "    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)\n",
    "\n",
    "    # convert to numpy if needed\n",
    "    if convert_to_numpy:\n",
    "        embeddings = embeddings.detach().numpy()\n",
    "\n",
    "    return embeddings[0] if input_was_string else embeddings"
   ]
  },
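  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normalization step lets the comparison below use a plain matrix product. Once every embedding has unit L2 norm, the dot product of two embeddings equals their cosine similarity. Here is a quick check with random vectors standing in for real embeddings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "raw = torch.randn(3, 768)  # stand-in for 3 unnormalized [CLS] vectors\n",
    "\n",
    "normed = F.normalize(raw, dim=-1)  # divide each row by its L2 norm\n",
    "\n",
    "dot_scores = normed @ normed.T\n",
    "cos_scores = F.cosine_similarity(raw.unsqueeze(1), raw.unsqueeze(0), dim=-1)  # (3, 3) via broadcasting\n",
    "\n",
    "print(normed.norm(dim=-1))  # each row has norm (approximately) 1\n",
    "print(torch.allclose(dot_scores, cos_scores, atol=1e-5))  # True"
   ]
  },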
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Comparison"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's run the function we wrote to get the embeddings of the two sentences:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embeddings:\n",
      "[[ 1.4549762e-02 -9.6840411e-03  3.7761475e-03 ... -8.5092714e-04\n",
      "   2.8417887e-02  6.3214332e-02]\n",
      " [ 3.3924331e-05 -3.2998275e-03  1.7206438e-02 ...  3.5703944e-03\n",
      "   1.8721525e-02 -2.0371782e-02]]\n",
      "Similarity scores:\n",
      "[[0.9999997 0.6077381]\n",
      " [0.6077381 0.9999999]]\n"
     ]
    }
   ],
   "source": [
    "embeddings = _encode(sentences)\n",
    "print(f\"Embeddings:\\n{embeddings}\")\n",
    "\n",
    "scores = embeddings @ embeddings.T\n",
    "print(f\"Similarity scores:\\n{scores}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, run the `FlagModel` API provided by FlagEmbedding on the same sentences:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embeddings:\n",
      "[[ 1.4549762e-02 -9.6840411e-03  3.7761475e-03 ... -8.5092714e-04\n",
      "   2.8417887e-02  6.3214332e-02]\n",
      " [ 3.3924331e-05 -3.2998275e-03  1.7206438e-02 ...  3.5703944e-03\n",
      "   1.8721525e-02 -2.0371782e-02]]\n",
      "Similarity scores:\n",
      "[[0.9999997 0.6077381]\n",
      " [0.6077381 0.9999999]]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import FlagModel\n",
    "\n",
    "# use a new name so the Transformers `model` used by `_encode` is not overwritten\n",
    "flag_model = FlagModel('BAAI/bge-base-en-v1.5')\n",
    "\n",
    "embeddings = flag_model.encode(sentences)\n",
    "print(f\"Embeddings:\\n{embeddings}\")\n",
    "\n",
    "scores = embeddings @ embeddings.T\n",
    "print(f\"Similarity scores:\\n{scores}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, the two encoding functions return identical results (at least to the printed precision). The full implementation in FlagEmbedding additionally processes large datasets in batches and supports GPUs and parallelization. Feel free to check the [source code](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/inference/embedder/encoder_only/base.py) for more details."
   ]
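  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the batching mentioned above, here is a minimal, hypothetical sketch. It is not FlagEmbedding's actual code: it splits the input into fixed-size chunks, encodes each chunk, and stacks the results. `encode_in_batches` and `dummy_encode` are illustrative names, and in this notebook you could pass `_encode` as `encode_fn`. Each sentence's [CLS] embedding depends only on its own tokens, since padding is masked out. Batching should therefore not change the results beyond floating-point noise, and the final check verifies this for the dummy encoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def encode_in_batches(encode_fn, sentences, batch_size=32):\n",
    "    # hypothetical helper: encode `sentences` chunk by chunk and stack the results\n",
    "    chunks = []\n",
    "    for start in range(0, len(sentences), batch_size):\n",
    "        batch = sentences[start:start + batch_size]\n",
    "        chunks.append(encode_fn(batch))\n",
    "    return np.concatenate(chunks, axis=0)\n",
    "\n",
    "def dummy_encode(batch):\n",
    "    # hypothetical stand-in for a real encoder: one L2-normalized 4-d vector per sentence\n",
    "    vecs = np.array([[len(s), s.count(' '), 1.0, ord(s[0])] for s in batch], dtype=np.float32)\n",
    "    return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)\n",
    "\n",
    "texts = ['embedding', 'I love machine learning and nlp', 'BGE uses [CLS] pooling']\n",
    "batched = encode_in_batches(dummy_encode, texts, batch_size=2)\n",
    "print(batched.shape)  # (3, 4)\n",
    "print(np.allclose(batched, dummy_encode(texts)))  # True"
   ]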
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
